
[Kernel] Add FlashInfer MoE A2A Kernel #36022

Merged
ywang96 merged 13 commits into vllm-project:main from CentML:wzhao/moe-a2a on Mar 16, 2026

Conversation

@leo-cf-tian (Contributor) commented Mar 4, 2026

Purpose

This PR is a port of PR #32217 to the vLLM top-of-tree after the modular kernel refactors in #32564. It adds the latest TRT-LLM gen A2A kernel from FlashInfer's MoE-A2A API (one-sided all-to-all), introduced in flashinfer-ai/flashinfer#2102. It should perform better than the older A2A kernel from #21003 (formerly flashinfer_all2allv) at large batch sizes.

The new kernel can be enabled by specifying --all2all-backend flashinfer_nvlink_one_sided. It is only available for nvfp4.

This PR also renames flashinfer_all2allv to flashinfer_nvlink_two_sided, as suggested in review: the new name is more descriptive and matches the naming of the new implementation.

Benchmarks show a noticeable throughput increase at high concurrency, up to 14% at a concurrency of 512.

[Figure: throughput vs. concurrency benchmark results]

Testing

The PR also adds test coverage from @stecasta.

  • Register FlashInferMoeA2APrepareAndFinalize in the modular kernel combinatorial test framework (mk_objects.py), enabling automatic multi-GPU testing against all compatible Expert backends with nvfp4 quantization
  • Register TrtLlmNvFp4ExpertsModular in the same framework (previously missing from the test registry)
  • Add parametrized tests validating the _supports_parallel_config incompatibility matrix for the new flashinfer_moe_a2a backend across 7 Expert types
  • Add a parity test ensuring flashinfer_moe_a2a and flashinfer_all2allv share the same incompatibility matrix, catching drift if one is updated without the other (a minimal sketch of this pattern follows)
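A rough, self-contained sketch of that parity-test pattern. The toy backend classes, the supports_parallel_config rule, and the config list below are placeholders, not vLLM's actual code; only the structure (parametrized configs, assert both backends agree) mirrors the tests added here:

```python
# Illustrative parity test: both A2A backends must agree on which parallel
# configurations they support, so drift in one is caught in CI.
from dataclasses import dataclass

import pytest


@dataclass(frozen=True)
class ParallelConfig:
    dp_size: int
    tp_size: int
    use_ep: bool


class ToyMoeA2ABackend:
    @staticmethod
    def supports_parallel_config(cfg: ParallelConfig) -> bool:
        # Placeholder rule: requires expert parallelism and >1 DP rank.
        return cfg.use_ep and cfg.dp_size > 1


class ToyAll2AllvBackend:
    @staticmethod
    def supports_parallel_config(cfg: ParallelConfig) -> bool:
        return cfg.use_ep and cfg.dp_size > 1


CONFIGS = [
    ParallelConfig(dp_size=8, tp_size=1, use_ep=True),
    ParallelConfig(dp_size=1, tp_size=8, use_ep=True),
    ParallelConfig(dp_size=8, tp_size=1, use_ep=False),
]


@pytest.mark.parametrize("cfg", CONFIGS)
def test_backend_parity(cfg: ParallelConfig) -> None:
    # Updating one backend's matrix without the other fails this test.
    assert (ToyMoeA2ABackend.supports_parallel_config(cfg)
            == ToyAll2AllvBackend.supports_parallel_config(cfg))
```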

Test plan

  • test_supports_parallel_config_flashinfer_moe_a2a — CPU only, 7 parametrized cases
  • test_supports_parallel_config_parity_with_all2allv — CPU only, 7 parametrized cases
  • Combinatorial coverage via test_modular_kernel_combinations_multigpu — multi-GPU, auto-generated from mk_objects.py registrations
  • Future: dedicated multi-GPU test with broader shape coverage once the A2A manager supports standalone initialization

Notes

The incompatibility matrix tests do not require a GPU and can run in any CI environment. The combinatorial multi-GPU tests require 2x Blackwell GPUs with FlashInfer trtllm_moe_alltoall support.

Reproduction

To reproduce our results, the server can be launched with the following configuration:

vllm serve nvidia/DeepSeek-R1-NVFP4 \
    --trust-remote-code \
    --max-num-seqs 1024 \
    --max-num-batched-tokens 2048 \
    --stream-interval 20 \
    --no-enable-prefix-caching \
    --kv-cache-dtype fp8 \
    --max-cudagraph-capture-size 2048 \
    --data-parallel-size 8 \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1 \
    --enable-expert-parallel \
    --gpu-memory-utilization 0.8 \
    --all2all-backend (None / flashinfer_all2allv / flashinfer_moe_a2a)

To verify correctness, you can run gsm8k:

lm_eval --model local-chat-completions --tasks gsm8k --model_args base_url=http://0.0.0.0:8000/v1/chat/completions,max_gen_toks=16384,num_concurrent=64 --batch_size auto --fewshot_as_multiturn --apply_chat_template

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces the FlashInfer MoE A2A kernel, which is a welcome addition for improving performance in large batch size scenarios. The integration of the new kernel is well-executed across the codebase, including configuration, communicator management, and kernel selection logic. I've identified one high-severity issue related to determining the number of GPUs per node, which could lead to suboptimal performance. My detailed feedback and a suggested fix are in the review comment.

Signed-off-by: Leo Tian <lctian@nvidia.com>
mergify bot commented Mar 12, 2026

Hi @leo-cf-tian, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

mergify bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @leo-cf-tian.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 13, 2026
@elvircrn (Contributor) commented Mar 13, 2026

@leo-cf-tian re-running with your latest commit and without my sed.

@tlrmchlsmth (Member) left a comment

This looks good to me, assuming we see correctness and are past the issue @elvircrn was running into

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Mar 14, 2026
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 14, 2026
@elvircrn (Contributor) commented

@tlrmchlsmth @leo-cf-tian

The trtllm scales issue appears for:

    --max-cudagraph-capture-size 32768 \
    --max-num-batched-tokens 32768 \

and switching to

    --max-cudagraph-capture-size 8192 \
    --max-num-batched-tokens 8192 \

made it go away.

Can confirm the int32/int64 indexing issue went away in both cases.

@tlrmchlsmth (Member) commented
Thanks @elvircrn. I don't expect many people to set those variables so high, but it could be nice to add a warning just in case.
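A minimal sketch of what such a warning could look like. The helper name, signature, and threshold rule are assumptions derived from the int32 overflow analysis that appears later in this thread, not vLLM code:

```python
import warnings

INT32_MAX = 2**31 - 1


def warn_if_int32_overflow_risk(max_num_batched_tokens: int,
                                hidden_size: int,
                                ep_size: int) -> None:
    # Hypothetical check: tokens gathered across all EP ranks times
    # hidden_size must stay within int32 for the TRT-LLM MoE kernels.
    if max_num_batched_tokens * ep_size * hidden_size > INT32_MAX:
        warnings.warn(
            f"max_num_batched_tokens={max_num_batched_tokens} with "
            f"ep_size={ep_size} and hidden_size={hidden_size} exceeds the "
            "int32 indexing range of the TRT-LLM MoE kernels; consider "
            "lowering max_num_batched_tokens.")
```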

@tlrmchlsmth (Member) left a comment

I'd like to get this into v0.18.0, which cuts tomorrow. Could you please fix the pre-commit issues? It looks like they are caused by divergence from main.

@wzhao18 (Contributor) commented Mar 15, 2026

I can help take a look tonight.

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@wzhao18 (Contributor) commented Mar 16, 2026

@tylertitsworth I fixed the merge conflicts. Can you start CI for this PR?

@mergify mergify bot removed the needs-rebase label Mar 16, 2026
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 16, 2026
@tlrmchlsmth (Member) left a comment

Could we get

A reviewer (Member) commented:

@wzhao18 could you hook up this kernel to CI?

needs to be added to .buildkite/test_areas/kernels.yaml

@wzhao18 (Contributor) commented Mar 16, 2026

Sorry, I thought I had posted the following response, but for some reason it was not submitted.

@tlrmchlsmth I re-examined the test and concluded it may not be very meaningful to add here: it checks the result of _supports_parallel_config against an expectation derived from the function itself, which seems redundant. I have therefore removed the test from the PR.

I think test_modular_kernel_combinations_multigpu should be a unified test that ensures both that (1) _supports_parallel_config is set correctly and (2) the combination actually works in practice. However, as far as I can tell, this test is not in the CI pipeline, and I am having problems running it even on the current main branch. I will look into this in more detail and potentially improve it in a future PR.

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@ywang96 ywang96 merged commit 2754231 into vllm-project:main Mar 16, 2026
64 of 65 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 16, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 16, 2026
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 16, 2026
Squashed from vllm-project#36022.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
yzh119 pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Mar 24, 2026
…ted hidden state scale shape" for EP32+ configs (#2853)

## 📌 Description

Fix `int32` overflow in `trtllm_fp4_block_scale_moe` that causes a
misleading `NotImplementedError: Unsupported hidden state scale shape`
when deploying large Expert Parallel configurations (e.g., EP32 with
`DeepSeek-R1 NVFP4`).


**Step 1, NVFP4 activation quantization (per EP rank)**

Each of the 32 EP ranks quantizes its local activations via
`vllm.ops.scaled_fp4_quant` with `is_sf_swizzled_layout=False`. From
[nvfp4_quant_entry.cu](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/csrc/quantization/fp4/nvfp4_quant_entry.cu):
```cpp
output_sf = torch::empty(
    {m, n / CVT_FP4_SF_VEC_SIZE},
    torch::TensorOptions().device(device).dtype(torch::kUInt8));
```
For m=10240 (`max_num_batched_tokens`), n=7168 (`hidden_size`):

- `hidden_states`: `[10240, 3584]` `uint8` (FP4 packed, 2 values per byte)
- `hidden_states_scale`: `[10240, 448]` `uint8`, viewed as `float8_e4m3fn`

No padding is applied in the non-swizzled layout. Scale numel = `10240 × 448 = 4,587,520`.
**Step 2, EP allgather via dispatch()**

`MoEPrepareAndFinalizeNaiveDPEPModular.prepare()` in
[naive_dp_ep.py](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py)
calls `get_ep_group().dispatch()`, which allgathers both `hidden_states`
and `hidden_states_scale` (passed as `extra_tensors`) across all 32 EP
ranks:

- `hidden_states`: 32 × `[10240, 3584]` → `[327680, 3584]`
- `hidden_states_scale`: 32 × `[10240, 448]` → `[327680, 448]`
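A quick arithmetic check of these shapes in plain Python (numbers only, no vLLM code involved):

```python
# Shape bookkeeping for the EP32 allgather described above.
ep, m, n = 32, 10_240, 7_168

hidden_states_shape = (ep * m, n // 2)  # FP4 packs 2 values per uint8 byte
scale_shape = (ep * m, n // 16)         # one FP8 scale per 16 FP4 values

assert hidden_states_shape == (327_680, 3_584)
assert scale_shape == (327_680, 448)
print(scale_shape[0] * scale_shape[1])  # 146,800,640 scale elements
```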

**Step 3, Scale reshape in vLLM wrapper**

In
[trtllm_nvfp4_moe.py](https://github.com/vllm-project/vllm/blob/a5e9d511defe2d2dc2dd270674fc197542fc0169/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py),
the scale is reshaped before being passed to FlashInfer:
```python
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
    *hidden_states.shape[:-1], -1)  # → [327680, 448]
```
At this point, `hidden_states_scale.numel()` = 327680 × 448 = 146,800,640.

**Step 4, int32 overflow in FlashInfer C++ kernel**

In `csrc/trtllm_fused_moe_kernel_launcher.cu`, the scale vector size is
computed as:
```cpp
int const num_tokens = hidden_states.size(0);   // int (32-bit) = 327680
int hidden_size = hidden_states.size(1);          // int (32-bit) = 3584
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;  // hidden_size = 7168
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();
//   ^^^^^^^^^^^^^^^^^^^^^^^^
//   int * int = int → OVERFLOW before promotion to int64 for division
```

The overflow:
- `327680 × 7168 = 2,348,810,240`, while `INT_MAX = 2,147,483,647`
- Since 2,348,810,240 > `INT_MAX`, the signed int32 multiplication overflows (undefined behavior in C++; on two's-complement architectures it wraps to -1,946,157,056)
- `vec_size = -1,946,157,056 / 146,800,640 = -13`
- Because -13 ≠ 16 and -13 ≠ 32, the check throws "Unsupported hidden state scale shape"
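The wraparound can be reproduced in plain Python by mimicking C's int32 arithmetic and truncating division (illustrative only):

```python
# Mimic the int32 arithmetic in the launcher's vec-size computation.
num_tokens, hidden_size = 327_680, 7_168
scale_numel = 327_680 * 448                  # 146,800,640

product = num_tokens * hidden_size           # 2,348,810,240 > INT_MAX
wrapped = (product + 2**31) % 2**32 - 2**31  # two's-complement int32 wrap
vec_size = int(wrapped / scale_numel)        # C division truncates toward 0

print(wrapped)   # -1946157056
print(vec_size)  # -13, failing the kernel's 16-or-32 check
```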

**Step 5, why some configurations work and others crash**

Overflow threshold for DeepSeek-R1 (hidden_size=7168):
- Max safe tokens: `INT_MAX / 7168 = 299,593`
- EP32 per-rank limit: `299,593 / 32 ≈ 9,362`
- Any `max_num_batched_tokens` > 9362 with EP32 triggers the overflow
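These thresholds, and the success/crash rows in the table below, can be checked with plain arithmetic:

```python
# Overflow boundary for hidden_size=7168 with EP32 (pure arithmetic).
INT_MAX = 2**31 - 1
hidden_size, ep = 7_168, 32

max_safe_tokens = INT_MAX // hidden_size    # 299,593 total tokens
per_rank_limit = max_safe_tokens // ep      # 9,362 tokens per EP rank

assert 9_360 * ep * hidden_size <= INT_MAX  # fits: the 9360 row succeeds
assert 9_370 * ep * hidden_size > INT_MAX   # overflows: the 9370 row crashes
```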

We confirmed the overflow boundary on an 8-node GB200 cluster (32 GPUs,
EP32, DP32) with --all2all-backend `allgather_reducescatter`:

| max_num_batched_tokens | Total tokens (×32) | M × 7168 | vs INT_MAX | Result |
| :--- | :--- | :--- | :--- | :--- |
| 9360 | 299,520 | 2,146,959,360 | < 2,147,483,647 | ✅ Success |
| 9370 | 299,840 | 2,149,253,120 | > 2,147,483,647 | ❌ **Crash** |
| 8192 (workaround) | 262,144 | 1,879,048,192 | < 2,147,483,647 | ✅ Success |
| 10240 (original) | 327,680 | 2,348,810,240 | > 2,147,483,647 | ❌ **Crash** |


**Reproduction**
vLLM serve with EP32:
```bash
vllm serve nvidia/DeepSeek-R1-NVFP4 \
    --tensor-parallel-size 1 \
    --data-parallel-size 32 \
    --enable-expert-parallel \
    --all2all-backend allgather_reducescatter \
    --max-num-batched-tokens 10240 \
    --kv-cache-dtype fp8 \
    --trust-remote-code
```
Crashes during engine initialization with `NotImplementedError: Unsupported hidden state scale shape.` (This issue was also found in vllm-project/vllm#36022 (comment).)



**Fix**

Promote the multiplication operands to `int64_t` before the division to prevent overflow:
- `hidden_states_scale_vec_size`: cast `num_tokens` to `int64_t` so the multiplication chain executes in 64-bit.
- `weight_scale_vec_size`: apply the same pattern with `local_num_experts` cast to `int64_t`, and declare the variable as `int64_t` for consistency.
```cpp
// In csrc/trtllm_fused_moe_kernel_launcher.cu
// Before (overflow-prone):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (num_tokens * hidden_size) / hidden_states_scale.value().numel();

// After (safe):
int const num_tokens = hidden_states.size(0);
int hidden_size = hidden_states.size(1);
if (hidden_states.dtype() == dl_uint8) hidden_size *= 2;
hidden_states_scale_vec_size =
    (static_cast<int64_t>(num_tokens) * hidden_size) /
    hidden_states_scale.value().numel();
```

The same pattern should also be applied to `weight_scale_vec_size` for safety:
```cpp
int64_t weight_scale_vec_size =
    (static_cast<int64_t>(local_num_experts) * intermediate_size
     * intermediate_size_factor * hidden_size) /
    gemm1_weights_scale.numel();
```

**Impact**
- Zero performance impact: these are CPU-side setup computations executed once before GPU kernel launch.
- Zero API change: no function signatures are modified.
- Unblocks EP32+ deployments of large-hidden-size models (DeepSeek-R1, etc.) with max_num_batched_tokens above the int32 threshold.

**Environment**
- Model: DeepSeek-R1-0528-FP4 (NVFP4, hidden_size=7168)
- Hardware: 8× GB200 nodes (32 GPUs), disaggregated prefill-decode
- Configuration: DP=32, EP=32, TP=1, PP=1
- vLLM: v0.17.2rc1 (bundled FlashInfer)



---------

Co-authored-by: Albert Cheng (Engrg-Hardware 1) <albecheng@login-lyris01.lyris.clusters.nvidia.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>

Labels

ci/build, documentation, frontend, gpt-oss, multi-modality, new-model, nvidia, performance, qwen, ready, rocm, structured-output, v1
